# Set random seed for reproducibility
set.seed(1234)

library(dplyr)     # data manipulation
library(lubridate) # date parsing (ymd)
library(lime)      # ML local interpretation
library(vip)       # ML global interpretation
library(pdp)       # ML global interpretation
library(ggplot2)   # visualization pkg leveraged by above packages
library(caret)     # ML model building
# Read in data
all.df <- read.csv("./data/all.df.csv")

# Convert date columns (ymd() yields NA with a warning on unparseable values)
all.df$dot <- ymd(all.df$dot)
all.df$dor <- ymd(all.df$dor)
all.df$bdate <- ymd(all.df$bdate)
all.df$pdate <- ymd(all.df$pdate)

# Convert all character strings to factors (required by caret/randomForest)
all.df <- all.df %>% mutate_if(is.character, as.factor)

# Make outcome a binary factor (yes/no relapse); "yes" listed first so it is
# treated as the positive class downstream
all.df$rbin <- factor(all.df$rbin, levels = c("yes", "no"))

# Filter out any tests that are post-relapse; keep non-relapsers (dor is NA)
all.df <- all.df[which(all.df$bdate < all.df$dor | is.na(all.df$dor)), ]

# Filter out relapse > 720 days
all.df <- all.df[which(all.df$rbin == "no" | all.df$rtime < 720), ]

# Filter out any rows with missing chimerism test values
test_cols <- c("bmc_cdw", "bmc_cd3", "bmc_cd15", "bmc_cd34",
               "pbc_cdw", "pbc_cd3", "pbc_cd15", "pbc_cd34")
all.df <- all.df[complete.cases(all.df[, test_cols]), ]
# NOTE(review): removed redundant `all.df <<- all.df` — at top level the
# global-assignment operator is a no-op self-assignment
# Sub out the required modeling data; model formula is:
#   rbin ~ txage + sex + rstatprtx + hla + tbi + gmgp + agvhd + cgvhd +
#     bmc_cdw + bmc_cd3 + bmc_cd15 + bmc_cd34 +
#     pbc_cdw + pbc_cd3 + pbc_cd15 + pbc_cd34
all.df2 <- all.df %>%
  select(rbin, txage, sex, rstatprtx, hla,
         tbi, gmgp, agvhd, cgvhd, ## Removed e as only 1 level
         bmc_cdw, bmc_cd3, bmc_cd15, bmc_cd34,
         pbc_cdw, pbc_cd3, pbc_cd15, pbc_cd34)
# Alternative fit with explicit 5-fold CV and class probabilities enabled,
# kept for reference (classProbs = TRUE is needed for probability-based
# resampling metrics):
# fit.caret <- train(
#   rbin ~ .,
#   data = all.df2,
#   method = 'rf',
#   trControl = trainControl(method = "cv", number = 5, classProbs = TRUE),
#   tuneLength = 1
# )
# Fit a random forest via caret using the default resampling scheme
# (25 bootstrap reps) and the default mtry tuning grid
fit.caret <- train(
  rbin ~ .,
  data = all.df2,
  method = 'rf'
)
# Print the resampling summary; transcript from the original run below
fit.caret
## Random Forest
##
## 140 samples
## 16 predictor
## 2 classes: 'yes', 'no'
##
## No pre-processing
## Resampling: Bootstrapped (25 reps)
## Summary of sample sizes: 140, 140, 140, 140, 140, 140, ...
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 2 0.8799954 0.5983077
## 10 0.9034804 0.7030361
## 19 0.9097140 0.7297991
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 19.
# Optional standalone rf model (used for permutation importance below)
fit.rf <- randomForest::randomForest(
  rbin ~ .,
  data = all.df2)

# Prediction wrapper for permutation importance: metric = "auc" requires
# numeric probabilities of the positive class, but the default predict()
# for a randomForest classifier returns factor class labels. The original
# commented-out call referenced a `pfun` wrapper; restore it here.
pfun <- function(object, newdata) {
  predict(object, newdata = newdata, type = "prob")[, "yes"]
}

# Quick permutation-importance plot (10 replicates)
vip(fit.rf, method = "permute", data = all.df2, target = "rbin",
    metric = "auc", pred_wrapper = pfun,
    reference_class = "no", nsim = 10) + ggtitle("randomForest: RF")

# Earlier ranger-based version, kept for reference:
# vis <- vi(fit.ranger, method = "permute", data = all.df2, target = "rbin",
#           metric = "auc", pred_wrapper = pfun,
#           reference_class = "no", nsim = 100)
# vip(vis, geom = "boxplot") # Figure 12

# Importance scores with 100 replicates for stable error bars
vis <- vi(fit.rf, method = "permute", data = all.df2, target = "rbin",
          metric = "auc", pred_wrapper = pfun,
          reference_class = "no", nsim = 100)
vip(vis, geom = "boxplot") # Figure 12
# Bar chart of permutation importance with +/- 1 SD error bars
# (StDev column is available because nsim > 1 in the vi() call above)
p_raw <- ggplot(vis, aes(Variable, Importance)) +
  geom_bar(stat = "identity", color = "black",
           position = position_dodge()) +
  geom_errorbar(aes(ymin = Importance - StDev,
                    ymax = Importance + StDev), width = 0.2) +
  coord_flip()
print(p_raw) # explicit print so the plot renders in script context

# Same chart with variables ordered by importance, publication theme
p <- ggplot(vis, aes(reorder(Variable, Importance), Importance)) +
  geom_bar(stat = "identity", color = "black",
           position = position_dodge()) +
  geom_errorbar(aes(ymin = Importance - StDev,
                    ymax = Importance + StDev), width = 0.2) +
  coord_flip() + theme_bw() + scale_x_discrete(name = "Variable")
print(p)
#ggsave("./lime_plots/vip.pdf", plot = p)
# Build a LIME explainer from the training data and the caret model;
# continuous features are discretised into 5 quantile bins
explainer_caret <- lime(all.df2, fit.caret, n_bins = 5)
# explainer_rf <- lime(all.df2, fit.rf, n_bins = 5)
# Inspect the explainer's components; transcript from the original run below
summary(explainer_caret)
## Length Class Mode
## model 24 train list
## preprocess 1 -none- function
## bin_continuous 1 -none- logical
## n_bins 1 -none- numeric
## quantile_bins 1 -none- logical
## use_density 1 -none- logical
## feature_type 17 -none- character
## bin_cuts 17 -none- list
## feature_distribution 17 -none- list
# summary(explainer_rf)
# Example explainer plot for patient 1 (all rows/tests for that patient)
patientID <- which(all.df$ID == 1)
explanation_caret <- explain(
  x = all.df2[patientID, ],
  explainer = explainer_caret,
  n_permutations = 5000,
  dist_fun = "gower",   # Gower distance handles mixed factor/numeric data
  kernel_width = .75,
  n_features = 10,      # show the 10 highest-weighted features
  feature_select = "highest_weights",
  labels = "yes"        # explain the predicted probability of relapse
)
p1 <- plot_features(explanation_caret)
print(p1) # p1 was previously assigned but never displayed
plot_explanations(explanation_caret)
# LIME feature plot for every patient
all_patients <- unique(all.df$ID)
for (i in seq_along(all_patients)) {
  patientID <- which(all.df$ID == all_patients[i])
  explanation_caret <- explain(
    x = all.df2[patientID, ],
    explainer = explainer_caret,
    n_permutations = 5000,
    dist_fun = "gower",
    kernel_width = .75,
    n_features = 10,
    feature_select = "highest_weights",
    labels = "yes"
  )
  p1 <- plot_features(explanation_caret) + ggtitle(paste("Patient", all_patients[i]))
  print(p1)
  #ggsave(paste0("./lime_plots/patient_",i,".pdf"), plot = p1)
}
# Patient 8 (ID = 5)
j <- 5
patientID <- which(all.df$ID == j)
## Print one line of table to check; transcript from the original run below
print(all.df[patientID[1], ])
## ID dot dor txage relapse sex rstatprtx hla tbi gmgp agvhd
## 5 5 2013-08-27 2014-09-08 18.4 yes M CR2 0 yes BM/PB,0,no no
## cgvhd test bdate btime rtime rbin bmc_cdw bmc_cd3 bmc_cd15 bmc_cd34
## 5 yes D1MC 2013-09-24 28 377 yes 99 83 100 98
## pdate ptime pbc_cdw pbc_cd3 pbc_cd15 pbc_cd34
## 5 2013-09-24 28 96 77 98 93
# One LIME explanation per test row for this patient
for (i in patientID) {
  explanation_caret <- explain(
    x = all.df2[i, ],
    explainer = explainer_caret,
    n_permutations = 5000,
    dist_fun = "gower",
    kernel_width = .75,
    n_features = 10,
    feature_select = "highest_weights",
    labels = "yes"
  )
  p1 <- plot_features(explanation_caret) + ggtitle(paste("Patient", j))
  print(p1)
}
# Patient 15 (ID = 43)
j <- 43
patientID <- which(all.df$ID == j)
## Print one line of table to check; transcript from the original run below
print(all.df[patientID[1], ])
## ID dot dor txage relapse sex rstatprtx hla tbi gmgp
## 89 43 2018-08-10 2020-04-21 6.5 yes M CR2 5 yes PB,≤5, yes
## agvhd cgvhd test bdate btime rtime rbin bmc_cdw bmc_cd3 bmc_cd15
## 89 no no D2MC 2018-10-10 61 620 yes 100 100 99
## bmc_cd34 pdate ptime pbc_cdw pbc_cd3 pbc_cd15 pbc_cd34
## 89 100 2018-10-10 61 100 100 100 100
# One LIME explanation per test row for this patient
for (i in patientID) {
  explanation_caret <- explain(
    x = all.df2[i, ],
    explainer = explainer_caret,
    n_permutations = 5000,
    dist_fun = "gower",
    kernel_width = .75,
    n_features = 10,
    feature_select = "highest_weights",
    labels = "yes"
  )
  p1 <- plot_features(explanation_caret) + ggtitle(paste("Patient", j))
  print(p1)
}
# Patient 16 (ID = 8)
j <- 8
patientID <- which(all.df$ID == j)
## Print one line of table to check; transcript from the original run below
print(all.df[patientID[1], ])
## ID dot dor txage relapse sex rstatprtx hla tbi gmgp agvhd
## 8 8 2018-09-18 2019-12-10 15.8 yes M CR1 0 yes BM/PB,0,no yes
## cgvhd test bdate btime rtime rbin bmc_cdw bmc_cd3 bmc_cd15 bmc_cd34
## 8 no D1MC 2018-10-16 28 448 yes 100 99 100 100
## pdate ptime pbc_cdw pbc_cd3 pbc_cd15 pbc_cd34
## 8 2018-10-16 28 100 100 99 100
# One LIME explanation per test row for this patient
for (i in patientID) {
  explanation_caret <- explain(
    x = all.df2[i, ],
    explainer = explainer_caret,
    n_permutations = 5000,
    dist_fun = "gower",
    kernel_width = .75,
    n_features = 10,
    feature_select = "highest_weights",
    labels = "yes"
  )
  p1 <- plot_features(explanation_caret) + ggtitle(paste("Patient", j))
  print(p1)
}
# Patient 17 (ID = 45)
j <- 45
patientID <- which(all.df$ID == j)
## Print one line of table to check; transcript from the original run below
print(all.df[patientID[1], ])
## ID dot dor txage relapse sex rstatprtx hla tbi gmgp
## 45 45 2018-11-07 2019-04-29 9.9 yes M CR3 4 yes PB,≤5, yes
## agvhd cgvhd test bdate btime rtime rbin bmc_cdw bmc_cd3 bmc_cd15
## 45 no no D1MC 2018-12-06 29 173 yes 100 99 100
## bmc_cd34 pdate ptime pbc_cdw pbc_cd3 pbc_cd15 pbc_cd34
## 45 100 2018-11-23 16 100 100 100 100
# One LIME explanation per test row for this patient
for (i in patientID) {
  explanation_caret <- explain(
    x = all.df2[i, ],
    explainer = explainer_caret,
    n_permutations = 5000,
    dist_fun = "gower",
    kernel_width = .75,
    n_features = 10,
    feature_select = "highest_weights",
    labels = "yes"
  )
  p1 <- plot_features(explanation_caret) + ggtitle(paste("Patient", j))
  print(p1)
}
# Author footnotes (from the rendered document):
# Stanford Medicine, dcshyr@stanford.edu
# University of Utah, simon.brewer@geog.utah.edu